/*=========================================================================
 *
 *  Copyright Insight Software Consortium
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/
#ifndef __itkDiffeomorphicDemonsRegistrationWithMaskFilter_hxx
#define __itkDiffeomorphicDemonsRegistrationWithMaskFilter_hxx

#include "itkDiffeomorphicDemonsRegistrationWithMaskFilter.h"
#include "itkSmoothingRecursiveGaussianImageFilter.h"

namespace itk
{
/**
 * Default constructor.
 *
 * Installs a newly created DemonsRegistrationFunctionType as the filter's
 * difference function and builds the helper filters that ApplyUpdate()
 * reuses on every iteration: a constant multiplier and an adder (both run
 * in place), a field exponentiator, and a vector warper with its own
 * interpolator. The first-order exponential approximation is off by default.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::DiffeomorphicDemonsRegistrationWithMaskFilter() :
  m_UseFirstOrderExp(false)
{
  // The demons function computes the per-pixel update; the superclass only
  // knows it through the generic finite-difference interface.
  typename DemonsRegistrationFunctionType::Pointer demonsFunction =
    DemonsRegistrationFunctionType::New();
  FiniteDifferenceFunctionType * const genericFunction =
    static_cast<FiniteDifferenceFunctionType *>( demonsFunction.GetPointer() );
  this->SetDifferenceFunction( genericFunction );

  // Scales the update buffer by the time step (in place).
  m_Multiplier = MultiplyByConstantType::New();
  m_Multiplier->InPlaceOn();

  // Computes exp(u) of the update field.
  m_Exponentiator = FieldExponentiatorType::New();

  // Composes displacement fields by warping with an interpolated field.
  FieldInterpolatorPointer interpolator = FieldInterpolatorType::New();
  m_Warper = VectorWarperType::New();
  m_Warper->SetInterpolator( interpolator );

  // Adds the warped field and the update (in place).
  m_Adder = AdderType::New();
  m_Adder->InPlaceOn();
}

/**
 * Down-cast the generic finite-difference function held by this filter to
 * DemonsRegistrationFunctionType (mutable access).
 *
 * \return a non-null pointer to the demons registration function.
 * \throws itk::ExceptionObject if the difference function is null or is not
 *         a DemonsRegistrationFunctionType.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
typename DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage,
                                                       TDisplacementField>::DemonsRegistrationFunctionType
* DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::DownCastDifferenceFunctionType()
  {
  // dynamic_cast yields null both for a null function and for a function of
  // an unrelated type; either case is reported below.
  DemonsRegistrationFunctionType *drfp =
    dynamic_cast<DemonsRegistrationFunctionType *>
    ( this->GetDifferenceFunction().GetPointer() );

  if( !drfp )
    {
    // Name the type actually being cast to, so the message is accurate.
    itkExceptionMacro(
      << "Could not cast difference function to DemonsRegistrationFunctionType");
    }

  return drfp;
  }

/**
 * Down-cast the generic finite-difference function held by this filter to
 * DemonsRegistrationFunctionType (read-only access).
 *
 * \return a non-null const pointer to the demons registration function.
 * \throws itk::ExceptionObject if the difference function is null or is not
 *         a DemonsRegistrationFunctionType.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
const typename DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage,
                                                             TDisplacementField>::DemonsRegistrationFunctionType
* DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::DownCastDifferenceFunctionType() const
  {
  // dynamic_cast yields null both for a null function and for a function of
  // an unrelated type; either case is reported below.
  const DemonsRegistrationFunctionType *drfp =
    dynamic_cast<const DemonsRegistrationFunctionType *>
    ( this->GetDifferenceFunction().GetPointer() );

  if( !drfp )
    {
    // Name the type actually being cast to, so the message is accurate.
    itkExceptionMacro(
      << "Could not cast difference function to DemonsRegistrationFunctionType");
    }

  return drfp;
  }

/**
 * Prepare the difference function for the next iteration.
 *
 * The current displacement (deformation, in ITKv3) field is handed to the
 * demons function first. The superclass implementation then initializes the
 * function.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
void
DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::InitializeIteration()
{
  DemonsRegistrationFunctionType * const demonsFunction =
    this->DownCastDifferenceFunctionType();

#if (ITK_VERSION_MAJOR < 4)
  demonsFunction->SetDeformationField( this->GetDeformationField() );
#else
  demonsFunction->SetDisplacementField( this->GetDisplacementField() );
#endif

  // Must run after the field is set: it initializes the difference function.
  Superclass::InitializeIteration();
}

/**
 * Return the metric value reported by the demons difference function.
 * Throws (via DownCastDifferenceFunctionType) if the function has the wrong type.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
double
DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::GetMetric() const
{
  return this->DownCastDifferenceFunctionType()->GetMetric();
}

/**
 * Return the intensity difference threshold stored in the demons difference
 * function.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
double
DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::GetIntensityDifferenceThreshold() const
{
  return this->DownCastDifferenceFunctionType()->GetIntensityDifferenceThreshold();
}

/**
 * Forward the intensity difference threshold to the demons difference
 * function.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
void
DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::SetIntensityDifferenceThreshold(double threshold)
{
  this->DownCastDifferenceFunctionType()->SetIntensityDifferenceThreshold(threshold);
}

/**
 * Return the maximum update step length stored in the demons difference
 * function. ApplyUpdate() treats a value <= 0 as "no imposed maximum".
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
double
DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::GetMaximumUpdateStepLength() const
{
  return this->DownCastDifferenceFunctionType()->GetMaximumUpdateStepLength();
}

/**
 * Forward the maximum update step length to the demons difference function.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
void
DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::SetMaximumUpdateStepLength(double threshold)
{
  this->DownCastDifferenceFunctionType()->SetMaximumUpdateStepLength(threshold);
}

/**
 * Return the RMS change reported by the demons difference function.
 * The reference is whatever the function's GetRMSChange() yields.
 * NOTE(review): this is only safe if that getter returns a reference to a
 * member rather than a value; confirm against the function's declaration.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
const double &
DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::GetRMSChange() const
{
  return this->DownCastDifferenceFunctionType()->GetRMSChange();
}

/**
 * Return the gradient type the demons difference function is configured to use.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
typename DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::GradientType
DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::GetUseGradientType() const
{
  return this->DownCastDifferenceFunctionType()->GetUseGradientType();
}

/**
 * Select the gradient type used by the demons difference function.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
void
DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::SetUseGradientType(GradientType gtype)
{
  this->DownCastDifferenceFunctionType()->SetUseGradientType(gtype);
}

/**
 * Allocate the update buffer with the output's regions (largest possible,
 * requested, buffered) and geometry (origin, spacing, direction).
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
void
DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::AllocateUpdateBuffer()
{
  const DisplacementFieldPointer reference = this->GetOutput();
  DisplacementFieldPointer updateBuffer = this->GetUpdateBuffer();

  // Physical-space geometry.
  updateBuffer->SetOrigin( reference->GetOrigin() );
  updateBuffer->SetSpacing( reference->GetSpacing() );
  updateBuffer->SetDirection( reference->GetDirection() );

  // Index-space regions.
  updateBuffer->SetLargestPossibleRegion( reference->GetLargestPossibleRegion() );
  updateBuffer->SetRequestedRegion( reference->GetRequestedRegion() );
  updateBuffer->SetBufferedRegion( reference->GetBufferedRegion() );

  // Allocate only after the buffered region is known.
  updateBuffer->Allocate();
}

/**
 * Apply the update field u computed during this iteration to the current
 * displacement field s (this filter's output).
 *
 * Steps:
 *  1. If GetSmoothUpdateField() is set, smooth the update buffer u.
 *  2. If |dt - 1| > 1e-4, scale u by dt in place using m_Multiplier.
 *  3. Compose s with the update:
 *     - m_UseFirstOrderExp:  s <- s o (Id + u)
 *       (warp s by u, then add u);
 *     - otherwise:           s <- s o exp(u)
 *       (exponentiate u, warp s by exp(u), then add exp(u)).
 *  4. Graft the adder's output as this filter's output, and copy the RMS
 *     change reported by the difference function into this filter.
 *  5. If displacement (deformation, in ITKv3) field smoothing is enabled,
 *     smooth the resulting field.
 *
 * \param dt time step. It is taken by value for ITKv3 and by const reference
 *        for ITKv4, to match the superclass's virtual signature in each version.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
void
DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
#if (ITK_VERSION_MAJOR < 4)
// This is for meeting the virtual function signature for ITKv3 polymorphic hierarchy
::ApplyUpdate(TimeStepType dt)
#else
::ApplyUpdate(const TimeStepType &dt)
#endif
  {
  // If we smooth the update buffer before applying it, then we are
  // approximating a viscous problem as opposed to an elastic problem
  if( this->GetSmoothUpdateField() )
    {
    this->SmoothUpdateField();
    }

  // Use time step if necessary. In many cases
  // the time step is one so this will be skipped
  if( vcl_fabs(dt - 1.0) > 1.0e-4 )
    {
    itkDebugMacro("Using timestep: " << dt);
#if (ITK_VERSION_MAJOR < 4)
    m_Multiplier->SetConstant(dt);
#else
    m_Multiplier->SetInput2(dt);
#endif
    m_Multiplier->SetInput( this->GetUpdateBuffer() );
    m_Multiplier->GraftOutput( this->GetUpdateBuffer() );
    // in place update
    m_Multiplier->Update();
    // graft output back to this->GetUpdateBuffer()
    this->GetUpdateBuffer()->Graft( m_Multiplier->GetOutput() );
    }

  if( this->m_UseFirstOrderExp )
    {
    // use s <- s o (Id +u)

    // skip exponential and compose the vector fields
    m_Warper->SetOutputOrigin( this->GetUpdateBuffer()->GetOrigin() );
    m_Warper->SetOutputSpacing( this->GetUpdateBuffer()->GetSpacing() );
    m_Warper->SetOutputDirection( this->GetUpdateBuffer()->GetDirection() );
    m_Warper->SetInput( this->GetOutput() );
#if (ITK_VERSION_MAJOR < 4)
    m_Warper->SetDeformationField( this->GetUpdateBuffer() );
#else
    m_Warper->SetDisplacementField( this->GetUpdateBuffer() );
#endif

    // The warper is not updated explicitly here; m_Adder->Update() below
    // pulls it through the pipeline.
    m_Adder->SetInput1( m_Warper->GetOutput() );
    m_Adder->SetInput2( this->GetUpdateBuffer() );

    m_Adder->GetOutput()->SetRequestedRegion(
      this->GetOutput()->GetRequestedRegion() );
    }
  else
    {
    // use s <- s o exp(u)

    // compute the exponential
    m_Exponentiator->SetInput( this->GetUpdateBuffer() );

    const double imposedMaxUpStep = this->GetMaximumUpdateStepLength();
    if( imposedMaxUpStep > 0.0 )
      {
      // Fixed iteration count chosen so that max(norm(Phi))/2^N <= 0.25:
      // N = ceil(2 + log2(imposedMaxUpStep)), clamped at 0.
      // Presumably the step length is in pixel-spacing units — TODO confirm.
      const double numiterfloat = 2.0 + vcl_log(imposedMaxUpStep) / vnl_math::ln2;
      unsigned int numiter = 0;
      if( numiterfloat > 0.0 )
        {
        numiter = Math::Ceil<unsigned int>(numiterfloat);
        }

      m_Exponentiator->AutomaticNumberOfIterationsOff();
      m_Exponentiator->SetMaximumNumberOfIterations(numiter);
      }
    else
      {
      m_Exponentiator->AutomaticNumberOfIterationsOn();
      // just set a high value so that automatic number of step
      // is not thresholded
      m_Exponentiator->SetMaximumNumberOfIterations(2000u);
      }

    m_Exponentiator->GetOutput()->SetRequestedRegion(
      this->GetOutput()->GetRequestedRegion() );

    m_Exponentiator->Update();

    // compose the vector fields
    m_Warper->SetOutputOrigin( this->GetUpdateBuffer()->GetOrigin() );
    m_Warper->SetOutputSpacing( this->GetUpdateBuffer()->GetSpacing() );
    m_Warper->SetOutputDirection( this->GetUpdateBuffer()->GetDirection() );
    m_Warper->SetInput( this->GetOutput() );
#if (ITK_VERSION_MAJOR < 4)
    m_Warper->SetDeformationField( m_Exponentiator->GetOutput() );
#else
    m_Warper->SetDisplacementField( m_Exponentiator->GetOutput() );
#endif

    m_Warper->Update();

    m_Adder->SetInput1( m_Warper->GetOutput() );
    m_Adder->SetInput2( m_Exponentiator->GetOutput() );

    m_Adder->GetOutput()->SetRequestedRegion(
      this->GetOutput()->GetRequestedRegion() );
    }

  // Triggers update
  m_Adder->Update();

  // Region passing stuff: the adder's result becomes this filter's output
  this->GraftOutput( m_Adder->GetOutput() );
// HACK:  This is not in ITKv4  this->GetOutput()->Modified();

  DemonsRegistrationFunctionType *drfp = this->DownCastDifferenceFunctionType();

  // Mirror the difference function's RMS change into this filter's state
  this->SetRMSChange( drfp->GetRMSChange() );

  /**
   * Smooth the deformation (displacement) field, if requested
   */
#if (ITK_VERSION_MAJOR < 4)
  if( this->GetSmoothDeformationField() )
    {
    this->SmoothDeformationField();
    }
#else
  if( this->GetSmoothDisplacementField() )
    {
    this->SmoothDisplacementField();
    }
#endif
  }

/**
 * Print the superclass state, then this filter's intensity difference
 * threshold and first-order-exponential flag.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
void
DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::PrintSelf(std::ostream & os, Indent indent) const
{
  Superclass::PrintSelf(os, indent);

  const double intensityThreshold = this->GetIntensityDifferenceThreshold();
  os << indent << "Intensity difference threshold: " << intensityThreshold << std::endl
     << indent << "Use First Order exponential: " << this->m_UseFirstOrderExp << std::endl;
}

/**
 * Intentionally a no-op. The fixed and moving images are expected to live in
 * different physical spaces, so the superclass's same-space check is skipped.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
void
DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::VerifyInputInformation()
{
  return;
}

/**
 * Return the moving-image mask held by the demons difference function.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
const typename DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>::MaskType
* DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::GetMovingImageMask() const
  {
  return this->DownCastDifferenceFunctionType()->GetMovingImageMask();
  }

/**
 * Return the fixed-image mask held by the demons difference function.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
const typename DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>::MaskType
* DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::GetFixedImageMask() const
  {
  return this->DownCastDifferenceFunctionType()->GetFixedImageMask();
  }

/**
 * Forward the moving-image mask to the demons difference function.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
void DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::SetMovingImageMask(MaskType *mask)
{
  this->DownCastDifferenceFunctionType()->SetMovingImageMask(mask);
}

/**
 * Forward the fixed-image mask to the demons difference function.
 */
template <class TFixedImage, class TMovingImage, class TDisplacementField>
void DiffeomorphicDemonsRegistrationWithMaskFilter<TFixedImage, TMovingImage, TDisplacementField>
::SetFixedImageMask(MaskType *mask)
{
  this->DownCastDifferenceFunctionType()->SetFixedImageMask(mask);
}
} // end namespace itk

#endif
